/*
 * jl_recovery_sketch.cpp
 * Logistic regression using sparse JL sketch.
 * Recovery done by multiplying by R^T where
 * R is a sparse JL matrix.
 */

#include "jl_recovery_sketch.h"

#include <math.h>

#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <new>
#include <numeric>
#include <stdexcept>

#include "util.h"

namespace wmsketch{

/*
 * Builds an empty sketch with `depth` rows of 2^log2_width cells each.
 * All weights start at zero, the bias at 0 and the scale at 1.
 *
 * log2_width: base-2 logarithm of the row width; must not exceed
 *             MAX_LOG2_WIDTH.
 * depth:      number of rows; must be nonzero.
 * seed:       seed passed to hash_fn_.
 * lr_init:    initial learning rate; must be positive (NaN is rejected).
 * l2_reg:     L2 regularization strength.
 *
 * Throws std::invalid_argument for an out-of-range parameter and
 * std::bad_alloc if the weight table cannot be allocated.
 */
JLRecoverySketch::JLRecoverySketch(
		uint32_t log2_width,
		uint32_t depth,
		int32_t seed,
		float lr_init,
		float l2_reg)
  : bias_{0.f},
	lr_init_{lr_init},
	l2_reg_{l2_reg},
	scale_{1.f},
	t_{0},
	depth_{depth},
	hash_fn_(depth, seed),
	hash_buf_(depth, 0) {

	if (log2_width > JLRecoverySketch::MAX_LOG2_WIDTH) {
		throw std::invalid_argument("Invalid sketch width");
	}

	// With zero rows there is no weights_[0] slot to hold the slab pointer.
	if (depth == 0) {
		throw std::invalid_argument("Sketch depth must be positive");
	}

	// Written as a negated '>' so that a NaN learning rate is rejected too.
	if (!(lr_init > 0.f)) {
		throw std::invalid_argument("Initial learning rate must be positive");
	}

	// Shift an unsigned one: a signed `1 << 31` overflows int.
	// NOTE(review): assumes MAX_LOG2_WIDTH is below 32 -- confirm in the header.
	const uint32_t width = uint32_t{1} << log2_width;
	width_mask_ = width - 1;

	// The table is one array of row pointers plus a single zero-initialised
	// slab holding all rows back to back. The cell count is widened to
	// size_t before multiplying instead of being formed in 32 bits.
	const std::size_t cells = static_cast<std::size_t>(width) * depth;

	weights_ = static_cast<float**>(std::calloc(depth, sizeof(float*)));
	if (weights_ == nullptr) {
		throw std::bad_alloc();
	}

	weights_[0] = static_cast<float*>(std::calloc(cells, sizeof(float)));
	if (weights_[0] == nullptr) {
		// The destructor does not run for a partially constructed object,
		// so the row-pointer array has to be released here.
		std::free(weights_);
		weights_ = nullptr;
		throw std::bad_alloc();
	}

	// Row 0 already points at the start of the slab.
	for (uint32_t row = 1; row < depth; ++row) {
		weights_[row] = weights_[0] + static_cast<std::size_t>(row) * width;
	}
}

/*
 * Releases the weight table: first the contiguous slab that row 0
 * points at, then the array of row pointers itself.
 */
JLRecoverySketch::~JLRecoverySketch() {
	float* const slab = weights_[0];
	free(slab);
	free(weights_);
}

/*
 * Returns the estimated weight of feature `key`: the raw estimate
 * from get_weight() multiplied by the current scale factor.
 */
float JLRecoverySketch::get(uint32_t key) {
	const float raw_estimate = get_weight(key);
	return scale_ * raw_estimate;
}

/*
 * Get the weights corresponding to the nonzero
 * coordinates in the sparse vector x. Weights
 * are stored in weight_sums_.
 */
void JLRecoverySketch::get_weights(const std::vector<std::pair<uint32_t, float>>& x) {
	uint64_t n = x.size();
	if (hash_buf_.size() < depth_ * n) {
		hash_buf_.resize(depth_ * n);
	}

	weight_sums_.resize(n);
	uint32_t* ph = hash_buf_.data();

	for (int idx = 0; idx < n; idx++) {
		/*
		 * Go through the nonzero entries of R_{idx},
		 * and take the dot product of it with weights.
		 *
		 * Note that x[idx].first, NOT idx, is the feature
		 * that the coordinate x[idx] corresponds to.
		 */
		hash_fn_.hash(ph + idx * depth_, x[idx].first);
		float dot_product = 0;
		for (int i = 0; i < depth_; i++) {
			uint32_t h = hash_buf_[idx * depth_ + i];
			int sgn = (h >> 31) ? +1 : -1;
			float jl_entry = sgn / sqrt(depth_);
			dot_product += jl_entry * weights_[i][h & width_mask_];
		}
		weight_sums_[idx] = dot_product;
	}
}

/*
 * Returns scale_ * z^T R x, where x is the sparse training example,
 * R is the sparse JL matrix, and z is the stored weight table.
 * An empty x yields 0 without touching any state.
 *
 * Calls get_weights(x), so weight_sums_ and hash_buf_ describe x
 * when this returns.
 */
float JLRecoverySketch::dot(const std::vector<std::pair<uint32_t, float>>& x) {
	if (x.empty()) return 0.f;
	get_weights(x);

	// weight_sums_[idx] is <sketched weights, R_{idx}>, so summing
	// x_{idx} * weight_sums_[idx] over the nonzero coordinates gives the
	// desired product. The index is size_t to match x.size().
	float dot_product = 0.f;
	for (std::size_t idx = 0; idx < x.size(); ++idx) {
		dot_product += x[idx].second * weight_sums_[idx];
	}
	dot_product *= scale_;
	return dot_product;
}

/*
 * Classifies x without modifying the weights: returns true when
 * dot(x) + bias_ is non-negative.
 */
bool JLRecoverySketch::predict(const std::vector<std::pair<uint32_t, float>>& x) {
	const float margin = dot(x) + bias_;
	if (margin >= 0.) {
		return true;
	}
	return false;
}

/*
 * Estimates the weight corresponding to feature index key.
 * Takes the dot product of weight vector and R_key, where
 * R is the sparse JL matrix.
 */
float JLRecoverySketch::get_weight(uint32_t key) {
	hash_fn_.hash(hash_buf_.data(), key);
	float dot_product = 0;
	for (int i = 0; i < depth_; i++) {
		uint32_t h = hash_buf_[i];
		int sgn = (h >> 31) ? +1 : -1;
		float jl_entry = sgn / sqrt(depth_);
		dot_product += jl_entry * weights_[i][h & width_mask_];
	}
	return dot_product;
}

float JLRecoverySketch::bias() {
	return bias_;
}

float JLRecoverySketch::scale() {
	return scale_;
}

/*
 * Updates weights and return true if our
 * prediction is 1, and false otherwise.
 *
 * Gradient update is same as logistic_sketch,
 * though we scale properly by 1/sqrt(depth).
 */
bool JLRecoverySketch::update(const std::vector<std::pair<uint32_t, float>>& x, bool label) {
	if (x.size() == 0) {
		return bias_ >= 0;
	}
	int y = label ? +1 : -1;
	float lr = lr_init_ / (1.f + lr_init_ * l2_reg_ * t_);
	float z = dot(x) + bias_;
	float g = logistic_grad(y * z);
	scale_ *= (1 - lr * l2_reg_);
	float u = lr * y * g / scale_;

	for (int idx = 0; idx < x.size(); idx++) {
		// Update the weights corresponding to
		// nonzero entries of R_{idx}.
		float val = x[idx].second;
		for (int i = 0; i < depth_; i++) {
			uint32_t h = hash_buf_[idx * depth_ + i];
			int sgn = (h >> 31) ? +1 : -1;
			weights_[i][h & width_mask_] -= (sgn / sqrt(depth_)) * u * val;
		}
	}

	bias_ -= lr * y * g;
	t_++;
	return z >= 0;
}

bool JLRecoverySketch::update(
	std::vector<float>& new_weights,
	const std::vector<std::pair<uint32_t, float>>& x,
	bool label) {

	uint64_t n = x.size();
	new_weights.resize(n);
	if (n == 0) {
		return bias_ >= 0;
	}

	int y = label ? +1 : -1;
	float lr = lr_init_ / (1.f + lr_init_ * l2_reg_ * t_);
	float z = dot(x) + bias_;
	float g = logistic_grad(y * z);
	scale_ *= (1 - lr * l2_reg_);
	float u = lr * y * g / scale_;

	for (int idx = 0; idx < n; idx++) {
		float val = x[idx].second;
		for (int i = 0; i < depth_; i++) {
			uint32_t h = hash_buf_[idx * depth_ + i];
			int sgn = (h >> 31) ? +1 : -1;
			weights_[i][h & width_mask_] -= (sgn / sqrt(depth_)) * u * val;
		}
		new_weights[idx] = weight_sums_[idx] - u * val;
	}

	bias_ -= lr * y * g;
	t_++;
	return z >= 0;
}


/*
 * Single-feature update, i.e. the example is the unit vector e_key.
 * Not used and not implemented: every call throws std::logic_error.
 */
bool JLRecoverySketch::update(uint32_t key, bool label) {
	// Both arguments are intentionally ignored.
	static_cast<void>(key);
	static_cast<void>(label);
	throw std::logic_error("Not implemented.");
}

}
